{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "from sklearn.metrics import matthews_corrcoef, confusion_matrix\n",
    "\n",
    "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
    "                              TensorDataset)\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from torch.nn import CrossEntropyLoss, MSELoss\n",
    "\n",
    "from tools import *\n",
    "from multiprocessing import Pool, cpu_count\n",
    "import convert_examples_to_features\n",
    "\n",
    "from tqdm import tqdm_notebook, trange\n",
    "import os\n",
    "from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification\n",
    "from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule\n",
    "\n",
    "# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows\n",
    "import logging\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The input data dir. Should contain the .tsv files (or other data files) for the task.\n",
    "DATA_DIR = \"data/\"\n",
    "\n",
    "# BERT model to load. Upstream this is one of the pre-trained model names\n",
    "# (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased,\n",
    "# bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese);\n",
    "# here it is the file name of a model archive that is loaded from CACHE_DIR.\n",
    "BERT_MODEL = 'yelp.tar.gz'\n",
    "\n",
    "# The name of the task to evaluate; it is named 'yelp' here.\n",
    "TASK_NAME = 'yelp'\n",
    "\n",
    "# The output directory where the fine-tuned model and checkpoints will be written.\n",
    "OUTPUT_DIR = f'outputs/{TASK_NAME}/'\n",
    "\n",
    "# The directory where the evaluation reports will be written to.\n",
    "REPORTS_DIR = f'reports/{TASK_NAME}_evaluation_reports/'\n",
    "\n",
    "# This is where BERT will look for pre-trained models to load parameters from.\n",
    "CACHE_DIR = 'cache/'\n",
    "\n",
    "# The maximum total input sequence length after WordPiece tokenization.\n",
    "# Sequences longer than this will be truncated, and sequences shorter than this will be padded.\n",
    "MAX_SEQ_LENGTH = 128\n",
    "\n",
    "TRAIN_BATCH_SIZE = 24\n",
    "EVAL_BATCH_SIZE = 8\n",
    "LEARNING_RATE = 2e-5\n",
    "NUM_TRAIN_EPOCHS = 1\n",
    "RANDOM_SEED = 42\n",
    "GRADIENT_ACCUMULATION_STEPS = 1\n",
    "WARMUP_PROPORTION = 0.1\n",
    "OUTPUT_MODE = 'classification'\n",
    "\n",
    "CONFIG_NAME = \"config.json\"\n",
    "WEIGHTS_NAME = \"pytorch_model.bin\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give every evaluation run its own numbered sub-directory (report_0, report_1, ...).\n",
    "# NOTE: this cell rebinds REPORTS_DIR, so run it once per kernel session; running it\n",
    "# again would nest a new report directory inside the previous one.\n",
    "os.makedirs(REPORTS_DIR, exist_ok=True)\n",
    "\n",
    "# Start from the number of existing entries and skip any index that is already taken,\n",
    "# so a gap left by a deleted report cannot cause a FileExistsError.\n",
    "report_index = len(os.listdir(REPORTS_DIR))\n",
    "while os.path.exists(os.path.join(REPORTS_DIR, f'report_{report_index}')):\n",
    "    report_index += 1\n",
    "\n",
    "# os.path.join keeps the path free of doubled separators even when REPORTS_DIR ends with '/'.\n",
    "REPORTS_DIR = os.path.join(REPORTS_DIR, f'report_{report_index}')\n",
    "os.makedirs(REPORTS_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_eval_report(task_name, labels, preds):\n",
    "    \"\"\"Build the evaluation report for a binary classification task.\n",
    "\n",
    "    Args:\n",
    "        task_name: Value stored under the \"task\" key of the report.\n",
    "        labels: Ground-truth class ids (assumed to be 0/1 - confirm against the label map).\n",
    "        preds: Predicted class ids, aligned element-wise with `labels`.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys \"task\", \"mcc\", \"tp\", \"tn\", \"fp\" and \"fn\".\n",
    "    \"\"\"\n",
    "    mcc = matthews_corrcoef(labels, preds)\n",
    "    # Pin the label set so the matrix is always 2x2; without it, data in which only\n",
    "    # one class occurs yields a 1x1 matrix and the four-way unpacking below fails.\n",
    "    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()\n",
    "    return {\n",
    "        \"task\": task_name,\n",
    "        \"mcc\": mcc,\n",
    "        \"tp\": tp,\n",
    "        \"tn\": tn,\n",
    "        \"fp\": fp,\n",
    "        \"fn\": fn\n",
    "    }\n",
    "\n",
    "def compute_metrics(task_name, labels, preds):\n",
    "    \"\"\"Check that predictions and labels line up, then build the evaluation report.\n",
    "\n",
    "    Raises AssertionError when `preds` and `labels` differ in length.\n",
    "    \"\"\"\n",
    "    n_preds, n_labels = len(preds), len(labels)\n",
    "    assert n_preds == n_labels\n",
    "    report = get_eval_report(task_name, labels, preds)\n",
    "    return report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file outputs/yelp/vocab.txt\n"
     ]
    }
   ],
   "source": [
    "# Load the tokenizer vocabulary that is stored in the model's output directory.\n",
    "# os.path.join is used so the path stays valid whether or not OUTPUT_DIR ends with a separator.\n",
    "vocab_path = os.path.join(OUTPUT_DIR, 'vocab.txt')\n",
    "tokenizer = BertTokenizer.from_pretrained(vocab_path, do_lower_case=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the dev split through the task's data processor.\n",
    "processor = BinaryClassificationProcessor()\n",
    "eval_examples = processor.get_dev_examples(DATA_DIR)\n",
    "eval_examples_len = len(eval_examples)\n",
    "\n",
    "# Label set of the task and its size.\n",
    "label_list = processor.get_labels()  # [0, 1] for binary classification\n",
    "num_labels = len(label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every label to its position in label_list.\n",
    "label_map = dict(zip(label_list, range(len(label_list))))\n",
    "\n",
    "# One argument tuple per example, in the form handed to the conversion worker.\n",
    "eval_examples_for_processing = [\n",
    "    (example, label_map, MAX_SEQ_LENGTH, tokenizer, OUTPUT_MODE)\n",
    "    for example in eval_examples\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Preparing to convert 38000 examples..\n",
      "Spawning 15 processes..\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "74086b36b1f54590b8f78178419ae431",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=38000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Leave one core free, but never ask for fewer than one worker:\n",
    "# Pool(0) raises ValueError, which is what happened on a single-core machine.\n",
    "process_count = max(1, cpu_count() - 1)\n",
    "if __name__ ==  '__main__':\n",
    "    print(f'Preparing to convert {eval_examples_len} examples..')\n",
    "    print(f'Spawning {process_count} processes..')\n",
    "    with Pool(process_count) as p:\n",
    "        # imap yields results lazily and in input order; list() drains it while the pool is still open.\n",
    "        feature_iter = p.imap(convert_examples_to_features.convert_example_to_feature, eval_examples_for_processing)\n",
    "        eval_features = list(tqdm_notebook(feature_iter, total=eval_examples_len))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _long_tensor(attribute):\n",
    "    \"\"\"Collect `attribute` from every entry of eval_features into one torch.long tensor.\"\"\"\n",
    "    return torch.tensor([getattr(feature, attribute) for feature in eval_features], dtype=torch.long)\n",
    "\n",
    "all_input_ids = _long_tensor('input_ids')\n",
    "all_input_mask = _long_tensor('input_mask')\n",
    "all_segment_ids = _long_tensor('segment_ids')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class ids are stored as integers for classification, targets as floats for regression.\n",
    "_label_dtypes = {\"classification\": torch.long, \"regression\": torch.float}\n",
    "\n",
    "# Fail here with a clear message instead of leaving all_label_ids undefined,\n",
    "# which would only surface later as a NameError.\n",
    "if OUTPUT_MODE not in _label_dtypes:\n",
    "    raise ValueError(f\"Unsupported OUTPUT_MODE: {OUTPUT_MODE!r}\")\n",
    "\n",
    "all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=_label_dtypes[OUTPUT_MODE])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each dataset item is the tuple (input_ids, input_mask, segment_ids, label_id).\n",
    "eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)\n",
    "\n",
    "# Run prediction for the full data, visiting the examples sequentially.\n",
    "eval_sampler = SequentialSampler(eval_data)\n",
    "eval_dataloader = DataLoader(\n",
    "    eval_data,\n",
    "    batch_size=EVAL_BATCH_SIZE,\n",
    "    sampler=eval_sampler,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_pretrained_bert.modeling:loading archive file cache/yelp.tar.gz\n",
      "INFO:pytorch_pretrained_bert.modeling:extracting archive file cache/yelp.tar.gz to temp dir C:\\Users\\chatu\\AppData\\Local\\Temp\\tmpatrcxjr9\n",
      "INFO:pytorch_pretrained_bert.modeling:Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 28996\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Load the model weights from the archive inside CACHE_DIR.\n",
    "# os.path.join is used so the path stays valid whether or not CACHE_DIR ends with a separator.\n",
    "model_archive = os.path.join(CACHE_DIR, BERT_MODEL)\n",
    "model = BertForSequenceClassification.from_pretrained(model_archive, cache_dir=CACHE_DIR, num_labels=len(label_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): BertLayerNorm()\n",
       "      (dropout): Dropout(p=0.1)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (1): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (2): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (3): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (4): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (5): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (6): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (7): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (8): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (9): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (10): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (11): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Move the model to the selected device. Assigning the result (Module.to returns the\n",
    "# module itself) keeps the cell from echoing the full module repr as its output.\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:***** Eval results *****\n",
      "INFO:root:  task = yelp\n",
      "INFO:root:  mcc = 0.9142765922875774\n",
      "INFO:root:  tp = 18134\n",
      "INFO:root:  tn = 18237\n",
      "INFO:root:  fp = 763\n",
      "INFO:root:  fn = 866\n",
      "INFO:root:  eval_loss = 0.11229695742577314\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the model on the whole eval set, then log and save the metrics.\n",
    "\n",
    "# The root logger is fetched explicitly so this cell does not depend on a\n",
    "# `logger` name leaking in through `from tools import *`.\n",
    "logger = logging.getLogger()\n",
    "\n",
    "# The loss function depends only on OUTPUT_MODE, so build it once instead of per batch.\n",
    "if OUTPUT_MODE == \"classification\":\n",
    "    loss_fct = CrossEntropyLoss()\n",
    "elif OUTPUT_MODE == \"regression\":\n",
    "    loss_fct = MSELoss()\n",
    "else:\n",
    "    raise ValueError(f\"Unsupported OUTPUT_MODE: {OUTPUT_MODE!r}\")\n",
    "\n",
    "model.eval()  # put the model into inference mode\n",
    "eval_loss = 0\n",
    "nb_eval_steps = 0\n",
    "batch_logits = []  # one numpy array of logits per batch\n",
    "\n",
    "for input_ids, input_mask, segment_ids, label_ids in tqdm_notebook(eval_dataloader, desc=\"Evaluating\"):\n",
    "    input_ids = input_ids.to(device)\n",
    "    input_mask = input_mask.to(device)\n",
    "    segment_ids = segment_ids.to(device)\n",
    "    label_ids = label_ids.to(device)\n",
    "\n",
    "    # No gradients are needed for evaluation.\n",
    "    # NOTE(review): segment_ids is passed before input_mask - presumably matching the\n",
    "    # model's positional signature; confirm against pytorch_pretrained_bert.\n",
    "    with torch.no_grad():\n",
    "        logits = model(input_ids, segment_ids, input_mask, labels=None)\n",
    "\n",
    "    # create eval loss and other metric required by the task\n",
    "    if OUTPUT_MODE == \"classification\":\n",
    "        tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))\n",
    "    else:  # regression\n",
    "        tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))\n",
    "\n",
    "    eval_loss += tmp_eval_loss.mean().item()\n",
    "    nb_eval_steps += 1\n",
    "    batch_logits.append(logits.detach().cpu().numpy())\n",
    "\n",
    "if nb_eval_steps == 0:\n",
    "    raise ValueError(\"eval_dataloader produced no batches; nothing to evaluate\")\n",
    "\n",
    "# Average of the per-batch losses.\n",
    "eval_loss = eval_loss / nb_eval_steps\n",
    "\n",
    "# Join all batches in a single pass; appending to a growing array inside the loop\n",
    "# would re-copy every earlier batch on each step.\n",
    "preds = np.concatenate(batch_logits, axis=0)\n",
    "if OUTPUT_MODE == \"classification\":\n",
    "    preds = np.argmax(preds, axis=1)\n",
    "elif OUTPUT_MODE == \"regression\":\n",
    "    preds = np.squeeze(preds)\n",
    "result = compute_metrics(TASK_NAME, all_label_ids.numpy(), preds)\n",
    "\n",
    "result['eval_loss'] = eval_loss\n",
    "\n",
    "# Log every metric and write the same lines to the report file.\n",
    "output_eval_file = os.path.join(REPORTS_DIR, \"eval_results.txt\")\n",
    "with open(output_eval_file, \"w\") as writer:\n",
    "    logger.info(\"***** Eval results *****\")\n",
    "    for key, value in result.items():\n",
    "        logger.info(\"  %s = %s\", key, str(value))\n",
    "        writer.write(\"%s = %s\\n\" % (key, str(value)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
